#' This is the companion code to the post
#' "Discrete Representation Learning with VQ-VAE and TensorFlow Probability"
#' on the TensorFlow for R blog.
#'
#' https://blogs.rstudio.com/tensorflow/posts/2019-01-24-vq-vae/

library(keras)
library(tensorflow)
library(tfprobability)
library(tfdatasets)

library(dplyr)
library(glue)

# curry has to be installed from GitHub because the CRAN version has no
# "set_defaults" function. Install it only when the installed copy is missing
# or lacks set_defaults, instead of re-downloading it on every run.
if (!requireNamespace("curry", quietly = TRUE) ||
    !("set_defaults" %in% getNamespaceExports("curry"))) {
  if (!requireNamespace("devtools", quietly = TRUE)) {
    install.packages("devtools")
  }
  devtools::install_github("thomasp85/curry")
}

library(curry)

# Python module providing assign_moving_average(), used by update_ema().
# NOTE(review): tf$python is a private module path; confirm it still exists
# in the installed TensorFlow version.
moving_averages <- tf$python$training$moving_averages


# Utilities --------------------------------------------------------

# Write two image grids for an epoch: one of reconstructions and one of
# images decoded from random prior samples. Returns whatever the second
# write_png() call returns.
visualize_images <- function(dataset, epoch, reconstructed_images, random_images) {
  emit <- function(desc, images) write_png(dataset, epoch, desc, images)
  emit("reconstruction", reconstructed_images)
  emit("random", random_images)
}

# Save up to 64 single-channel images as an 8 x 8 grid to
# "<dataset>_epoch_<epoch>_<desc>.png" in the working directory.
#
# images: array or tensor of shape (n, height, width, channels); only the
#   first channel is drawn. Values are expected in [0, 1] (pixel
#   intensities / Bernoulli means) and are drawn on a fixed grey scale.
# Returns the written file path, invisibly.
write_png <- function(dataset, epoch, desc, images) {
  # Convert once up front (a no-op for plain R arrays) instead of slicing
  # the tensor once per image.
  images <- as.array(images)
  n_images <- min(64L, dim(images)[1])

  path <- paste0(dataset, "_epoch_", epoch, "_", desc, ".png")
  png(path)
  # Close the device even if plotting fails part-way through.
  on.exit(dev.off(), add = TRUE)

  par(mfcol = c(8, 8))
  par(mar = c(0.5, 0.5, 0.5, 0.5),
      xaxs = "i",
      yaxs = "i")
  for (i in seq_len(n_images)) {
    img <- images[i, , , 1]
    # image() draws row 1 at the bottom; flip so the picture is upright.
    img <- t(apply(img, 2, rev))
    image(
      seq_len(nrow(img)),
      seq_len(ncol(img)),
      img,
      # fixed [0, 1] scale so tiles are comparable (no per-tile stretching)
      zlim = c(0, 1),
      col = gray((0:255) / 255),
      xaxt = "n",
      yaxt = "n"
    )
  }
  invisible(path)
}


# Setup and preprocessing -------------------------------------------------

np <- import("numpy")

# Download the KMNIST training images (https://github.com/rois-codh/kmnist)
# into the working directory as "kmnist-train-imgs.npz", unless the file is
# already there. Returns the file name, invisibly.
download_data <- function() {
  dest <- "kmnist-train-imgs.npz"
  if (file.exists(dest)) {
    return(invisible(dest))
  }

  # Download to a temporary name first so an interrupted transfer does not
  # leave a truncated file that would pass the file.exists() check above.
  partial <- paste0(dest, ".part")
  on.exit(unlink(partial), add = TRUE)

  # mode = "wb": the archive is binary; text mode corrupts it on Windows.
  status <- download.file(
    "http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz",
    destfile = partial,
    mode = "wb"
  )
  if (status != 0) {
    stop("Downloading ", dest, " failed with status ", status, call. = FALSE)
  }
  if (!file.rename(partial, dest)) {
    stop("Could not move ", partial, " to ", dest, call. = FALSE)
  }
  invisible(dest)
}
download_data()

# np$load() on an .npz archive returns an NpzFile that holds an open file
# handle; read the single stored array and close the archive afterwards.
kuzushiji <- local({
  npz <- np$load("kmnist-train-imgs.npz")
  on.exit(npz$close(), add = TRUE)
  npz$get("arr_0")
})

# add a trailing channel axis, cast to float and scale pixels to [0, 1]
train_images <- kuzushiji %>%
  k_expand_dims() %>%
  k_cast(dtype = "float32")
train_images <- train_images / 255

# number of training images (also used as the shuffle buffer size)
buffer_size <- 60000
batch_size <- 64
num_examples_to_generate <- batch_size

# Full batches per epoch: the incomplete last batch is dropped below
# (drop_remainder = TRUE), so use integer division.
batches_per_epoch <- buffer_size %/% batch_size

train_dataset <- tensor_slices_dataset(train_images) %>%
  dataset_shuffle(buffer_size) %>%
  dataset_batch(batch_size, drop_remainder = TRUE)

# sanity check: pull one batch and inspect its shape
iter <- make_iterator_one_shot(train_dataset)
batch <- iterator_get_next(iter)
batch %>% dim()

# Params ------------------------------------------------------------------

# Optimisation
learning_rate <- 0.001

# Latent space / codebook
latent_size <- 1L   # latent vectors per image (encoder output rows)
num_codes <- 64L    # number of codebook entries
code_size <- 16L    # dimensionality of each latent vector / code

# Architecture
base_depth <- 32
activation <- "elu"
input_shape <- c(28, 28, 1)

# Loss weighting and codebook EMA
beta <- 0.25        # weight of the commitment loss
decay <- 0.99       # decay passed to the EMA codebook update

# Models -------------------------------------------------------------------

# Conv / transposed-conv constructors pre-configured with "same" padding and
# the shared activation. Callers can still override either default per call
# (the encoder/decoder below pass padding = "valid" and activation = "linear").
default_conv <- set_defaults(
  layer_conv_2d,
  list(padding = "same", activation = activation)
)
default_deconv <- set_defaults(
  layer_conv_2d_transpose,
  list(padding = "same", activation = activation)
)

# Encoder ------------------------------------------------------------------

# Convolutional encoder mapping an image batch to `latent_size` continuous
# vectors of length `code_size` per image, which the vector quantizer then
# snaps to codebook entries.
encoder_model <- function(name = NULL,
                          code_size) {

  keras_model_custom(name = name, function(self) {
    self$conv1 <- default_conv(filters = base_depth, kernel_size = 5)
    self$conv2 <- default_conv(filters = base_depth, kernel_size = 5, strides = 2)
    self$conv3 <- default_conv(filters = 2 * base_depth, kernel_size = 5)
    self$conv4 <- default_conv(filters = 2 * base_depth, kernel_size = 5, strides = 2)
    # 7 x 7 "valid" kernel collapses the spatial dims to 1 x 1
    self$conv5 <- default_conv(filters = 4 * latent_size, kernel_size = 7, padding = "valid")
    self$flatten <- layer_flatten()
    self$dense <- layer_dense(units = latent_size * code_size)
    self$reshape <- layer_reshape(target_shape = c(latent_size, code_size))

    # Shapes below are for the 28 x 28 x 1 inputs used in this script;
    # the first dimension is the batch.
    function(x, mask = NULL) {
      h <- self$conv1(x)      # (batch, 28, 28, base_depth)
      h <- self$conv2(h)      # (batch, 14, 14, base_depth)
      h <- self$conv3(h)      # (batch, 14, 14, 2 * base_depth)
      h <- self$conv4(h)      # (batch, 7, 7, 2 * base_depth)
      h <- self$conv5(h)      # (batch, 1, 1, 4 * latent_size)
      h <- self$flatten(h)    # (batch, 4 * latent_size)
      h <- self$dense(h)      # (batch, latent_size * code_size)
      self$reshape(h)         # (batch, latent_size, code_size)
    }
  })
}


# Decoder ------------------------------------------------------------------

# Transposed-convolution decoder mapping a flat latent of length
# `input_size` to an independent Bernoulli distribution over images of
# shape `output_shape` (height, width, channels).
decoder_model <- function(name = NULL,
                          input_size,
                          output_shape) {

  keras_model_custom(name = name, function(self) {
    self$reshape1 <- layer_reshape(target_shape = c(1, 1, input_size))
    # 7 x 7 "valid" transposed kernel grows 1 x 1 to 7 x 7
    self$deconv1 <- default_deconv(filters = 2 * base_depth, kernel_size = 7, padding = "valid")
    self$deconv2 <- default_deconv(filters = 2 * base_depth, kernel_size = 5)
    self$deconv3 <- default_deconv(filters = 2 * base_depth, kernel_size = 5, strides = 2)
    self$deconv4 <- default_deconv(filters = base_depth, kernel_size = 5)
    self$deconv5 <- default_deconv(filters = base_depth, kernel_size = 5, strides = 2)
    self$deconv6 <- default_deconv(filters = base_depth, kernel_size = 5)
    # linear output layer: its outputs are used as Bernoulli logits
    self$conv1 <- default_conv(filters = output_shape[3], kernel_size = 5, activation = "linear")

    # Shapes below are for the 28 x 28 x 1 setup used in this script;
    # the first dimension is the batch.
    function(x, mask = NULL) {
      h <- self$reshape1(x)     # (batch, 1, 1, input_size)
      h <- self$deconv1(h)      # (batch, 7, 7, 2 * base_depth)
      h <- self$deconv2(h)      # (batch, 7, 7, 2 * base_depth)
      h <- self$deconv3(h)      # (batch, 14, 14, 2 * base_depth)
      h <- self$deconv4(h)      # (batch, 14, 14, base_depth)
      h <- self$deconv5(h)      # (batch, 28, 28, base_depth)
      h <- self$deconv6(h)      # (batch, 28, 28, base_depth)
      logits <- self$conv1(h)   # (batch, 28, 28, output_shape[3])
      # treat all image dimensions as one event so log_prob is per image
      tfd_independent(
        tfd_bernoulli(logits = logits),
        reinterpreted_batch_ndims = length(output_shape)
      )
    }
  })
}

# Vector quantizer -------------------------------------------------------------------

vector_quantizer_model <- 
  function(name = NULL, num_codes, code_size) {

    # Nearest-neighbour vector quantizer. Holds a (num_codes, code_size)
    # codebook plus non-trainable EMA statistics (ema_count, ema_means)
    # that update_ema() uses to move the codebook.
    keras_model_custom(name = name, function(self) {
      self$num_codes <- num_codes
      self$code_size <- code_size
      self$codebook <- tf$compat$v1$get_variable(
        "codebook",
        shape = c(num_codes, code_size),
        dtype = tf$float32
      )
      self$ema_count <- tf$compat$v1$get_variable(
        name = "ema_count",
        shape = c(num_codes),
        initializer = tf$constant_initializer(0),
        trainable = FALSE
      )
      # EMA of code means starts from the initial codebook values
      self$ema_means <- tf$compat$v1$get_variable(
        name = "ema_means",
        initializer = self$codebook$initialized_value(),
        trainable = FALSE
      )

      # Returns list(quantized codes, one-hot assignments).
      # x is assumed to be (bs, latent, code_size) as produced by the encoder.
      function(x, mask = NULL) {
        # codebook broadcastable against x: (1, 1, num_codes, code_size)
        codebook_4d <- tf$reshape(
          self$codebook,
          c(1L, 1L, self$num_codes, self$code_size)
        )

        # (bs, latent, 1, code_size) - (1, 1, num_codes, code_size),
        # Euclidean norm over the last axis -> (bs, latent, num_codes)
        offsets <- tf$expand_dims(x, axis = 2L) - codebook_4d
        distances <- tf$norm(offsets, axis = 3L)

        # index of the closest code: (bs, latent)
        closest <- tf$argmin(distances, axis = 2L)
        # (bs, latent, num_codes)
        one_hot_assignments <- tf$one_hot(closest, depth = self$num_codes)

        # zero out every code except the chosen one and sum it out:
        # (bs, latent, num_codes, 1) * (1, 1, num_codes, code_size)
        # -> (bs, latent, code_size)
        selected <- tf$expand_dims(one_hot_assignments, -1L) * codebook_4d
        nearest_codebook_entries <- tf$reduce_sum(selected, axis = 2L)

        list(nearest_codebook_entries, one_hot_assignments)
      }
    })
  }


# Update codebook ------------------------------------------------------

# Move the codebook towards the running mean of the encoder outputs assigned
# to each code, using exponential moving averages with the given `decay`.
# The EMA variables are updated in place; the return value is the result
# of assigning the new codebook.
update_ema <- function(vector_quantizer,
                       one_hot_assignments,
                       codes,
                       decay) {
  # how many inputs chose each code in this batch: (num_codes)
  batch_counts <- tf$reduce_sum(one_hot_assignments, axis = c(0L, 1L))
  ema_count <- moving_averages$assign_moving_average(
    vector_quantizer$ema_count,
    batch_counts,
    decay,
    zero_debias = FALSE
  )

  # sum of the encoder outputs assigned to each code, with every other
  # output masked out: (num_codes, code_size); divided by counts below
  masked_codes <- tf$expand_dims(codes, 2L) * tf$expand_dims(one_hot_assignments, 3L)
  batch_sums <- tf$reduce_sum(masked_codes, axis = c(0L, 1L))
  ema_sums <- moving_averages$assign_moving_average(
    vector_quantizer$ema_means,
    batch_sums,
    decay,
    zero_debias = FALSE
  )

  # small epsilon avoids dividing by zero for codes that were never chosen
  count_column <- tf$expand_dims(ema_count + 1e-5, axis = -1L)
  tf$compat$v1$assign(vector_quantizer$codebook, ema_sums / count_column)
}


# Training setup -----------------------------------------------------------

# Models are built in the order encoder, decoder, quantizer.
encoder <- encoder_model(code_size = code_size)
decoder <- decoder_model(
  input_size = latent_size * code_size,
  output_shape = input_shape
)
vector_quantizer <- vector_quantizer_model(
  num_codes = num_codes,
  code_size = code_size
)

optimizer <- tf$optimizers$Adam(learning_rate = learning_rate)

# Checkpoint everything needed to resume training.
checkpoint_dir <- "./vq_vae_checkpoints"
checkpoint_prefix <- file.path(checkpoint_dir, "ckpt")
checkpoint <- tf$train$Checkpoint(
  optimizer = optimizer,
  encoder = encoder,
  decoder = decoder,
  vector_quantizer_model = vector_quantizer
)

checkpoint$save(file_prefix = checkpoint_prefix)

# Training loop -----------------------------------------------------------

num_epochs <- 20

# Uniform prior over the codebook for each latent position. It has no
# trainable parameters, so build it once rather than on every batch; it is
# also used below to sample random codes for visualization.
prior_dist <- tfd_multinomial(total_count = 1,
                              logits = tf$zeros(c(latent_size, num_codes)))

# average of an accumulated per-batch loss over the batches seen this epoch
mean_per_batch <- function(total) {
  round(as.numeric(total) / num_batches, 4)
}

for (epoch in seq_len(num_epochs)) {

  iter <- make_iterator_one_shot(train_dataset)

  total_loss <- 0
  reconstruction_loss_total <- 0
  commitment_loss_total <- 0
  prior_loss_total <- 0
  # counted rather than derived from buffer_size / batch_size
  num_batches <- 0

  until_out_of_range({

    x <- iterator_get_next(iter)

    with(tf$GradientTape() %as% tape, {

      codes <- encoder(x)
      c(nearest_codebook_entries, one_hot_assignments) %<-% vector_quantizer(codes)
      # Straight-through estimator: the forward pass uses the quantized
      # codes, while gradients w.r.t. the decoder input flow unchanged to
      # the encoder output.
      codes_straight_through <- codes + tf$stop_gradient(nearest_codebook_entries - codes)
      decoder_distribution <- decoder(codes_straight_through)

      reconstruction_loss <-
        -tf$reduce_mean(decoder_distribution$log_prob(x))

      # pulls encoder outputs towards their (fixed) assigned codebook entries
      commitment_loss <- tf$reduce_mean(
        tf$square(codes - tf$stop_gradient(nearest_codebook_entries))
      )

      prior_loss <- -tf$reduce_mean(
        tf$reduce_sum(prior_dist$log_prob(one_hot_assignments), 1L)
      )

      loss <-
        reconstruction_loss + beta * commitment_loss + prior_loss
    })

    # One backward pass and one optimizer step for all weights. Two
    # apply_gradients() calls per batch would advance the optimizer's step
    # counter twice, distorting Adam's bias correction.
    trainable_variables <- c(encoder$variables, decoder$variables)
    gradients <- tape$gradient(loss, trainable_variables)
    optimizer$apply_gradients(purrr::transpose(list(
      gradients, trainable_variables
    )))

    # the codebook is not updated by the optimizer but by moving averages
    update_ema(vector_quantizer,
               one_hot_assignments,
               codes,
               decay)

    total_loss <- total_loss + loss
    reconstruction_loss_total <-
      reconstruction_loss_total + reconstruction_loss
    commitment_loss_total <- commitment_loss_total + commitment_loss
    prior_loss_total <- prior_loss_total + prior_loss
    num_batches <- num_batches + 1
  })

  checkpoint$save(file_prefix = checkpoint_prefix)

  cat(
    glue(
      "Loss (epoch): {epoch}:",
      "  {mean_per_batch(total_loss)} loss",
      "  {mean_per_batch(reconstruction_loss_total)} reconstruction_loss",
      "  {mean_per_batch(commitment_loss_total)} commitment_loss",
      "  {mean_per_batch(prior_loss_total)} prior_loss"
    ),
    "\n"
  )

  # display example images (choose your frequency)
  if (TRUE) {
    # reconstructions of the last batch of the epoch
    reconstructed_images <- decoder_distribution$mean()
    # (64, 1, 16)
    prior_samples <- tf$reduce_sum(
      # selects one of the codes (masking out 63 of 64 codes)
      # (bs, 1, 64, 1)
      tf$expand_dims(prior_dist$sample(num_examples_to_generate), -1L) *
        # (1, 1, 64, 16)
        tf$reshape(vector_quantizer$codebook,
                   c(1L, 1L, num_codes, code_size)),
      axis = 2L
    )
    decoded_distribution_given_random_prior <-
      decoder(prior_samples)
    random_images <- decoded_distribution_given_random_prior$mean()
    visualize_images("k", epoch, reconstructed_images, random_images)
  }
}
